{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4620bc8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip3 install gpt-2-simple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "477a8024",
   "metadata": {},
   "outputs": [],
   "source": [
    "from gpt_2_simple.src import model as gpt2_model, encoder\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9528fd8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path of the JSON file whose values override the default model hyper-parameters.\n",
    "params = '117m-hparams.json'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc10d606",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from the library defaults, then overlay the values stored in the JSON file.\n",
    "hparams = gpt2_model.default_hparams()\n",
    "with open(params, encoding='utf-8') as f:\n",
    "    hparams.override_from_dict(json.load(f))\n",
    "\n",
    "# Decode the JSON files explicitly as UTF-8 (as is already done for vocab.bpe)\n",
    "# instead of relying on the platform's locale encoding.\n",
    "with open('encoder.json', 'r', encoding='utf-8') as f:\n",
    "    en = json.load(f)\n",
    "with open('vocab.bpe', 'r', encoding=\"utf-8\") as f:\n",
    "    bpe_data = f.read()\n",
    "\n",
    "# The first and last entries of the split are dropped (presumably a header line and\n",
    "# the empty string after the final newline -- verify against the vocab file).\n",
    "# Every remaining line becomes one merge rule.\n",
    "bpe_merges = [\n",
    "    tuple(merge_str.split()) for merge_str in bpe_data.split('\\n')[1:-1]\n",
    "]\n",
    "enc_malay = encoder.Encoder(encoder=en, bpe_merges=bpe_merges)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04232aaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "def top_k_logits(logits, k):\n",
    "    \"\"\"Mask every logit that is smaller than the k-th largest value of its row.\n",
    "\n",
    "    Masked entries are set to -1e10. The k == 0 test is done with tf.cond,\n",
    "    so k may be a tensor; for k == 0 the logits are returned unchanged.\n",
    "    \"\"\"\n",
    "\n",
    "    def _keep_logits():\n",
    "        return logits\n",
    "\n",
    "    def _mask_below_kth():\n",
    "        top_values, _ = tf.nn.top_k(logits, k=k)\n",
    "        # Smallest retained value per row; the new axis lets it broadcast in the comparison.\n",
    "        kth_value = top_values[:, -1, tf.newaxis]\n",
    "        below_kth = logits < kth_value\n",
    "        masked_out = tf.ones_like(logits, dtype=logits.dtype) * -1e10\n",
    "        return tf.where(below_kth, masked_out, logits)\n",
    "\n",
    "    return tf.cond(\n",
    "        pred=tf.equal(k, 0),\n",
    "        true_fn=_keep_logits,\n",
    "        false_fn=_mask_below_kth,\n",
    "    )\n",
    "\n",
    "\n",
    "def top_p_logits(logits, p):\n",
    "    \"\"\"Nucleus (top-p) filtering of a batch of logits.\n",
    "\n",
    "    Per row, the logits are sorted in descending order and kept while the\n",
    "    probability mass of the entries before them is below p; every logit\n",
    "    smaller than the smallest kept one is set to -1e10.\n",
    "    \"\"\"\n",
    "    with tf.variable_scope('top_p_logits'):\n",
    "        sorted_logits = tf.sort(logits, direction='DESCENDING')\n",
    "        sorted_probs = tf.nn.softmax(sorted_logits)\n",
    "        # Exclusive cumsum: mass of the entries that come before each position.\n",
    "        mass_before = tf.cumsum(sorted_probs, axis=1, exclusive=True)\n",
    "        inside_nucleus = mass_before < p\n",
    "        # Positions outside the nucleus get the sentinel 1000 so they do not win the\n",
    "        # row minimum below (assumes real logits stay below 1000).\n",
    "        sentinel = tf.ones_like(sorted_logits) * 1000\n",
    "        kept_logits = tf.where(inside_nucleus, sorted_logits, sentinel)\n",
    "        cutoff = tf.reduce_min(\n",
    "            input_tensor=kept_logits, axis=1, keepdims=True\n",
    "        )\n",
    "        below_cutoff = logits < cutoff\n",
    "        masked_out = tf.ones_like(logits, dtype=logits.dtype) * -1e10\n",
    "        return tf.where(below_cutoff, masked_out, logits)\n",
    "\n",
    "\n",
    "def sample_sequence(\n",
    "    hparams,\n",
    "    length,\n",
    "    start_token=None,\n",
    "    batch_size=None,\n",
    "    context=None,\n",
    "    temperature=1,\n",
    "    top_k=0,\n",
    "    top_p=0.0,\n",
    "):\n",
    "    \"\"\"Sample `length` tokens that continue `context`, one token per loop step.\n",
    "\n",
    "    Args:\n",
    "        hparams: hyper-parameters forwarded to gpt2_model.model and past_shape.\n",
    "        length: number of tokens to generate (the loop runs while lens < length).\n",
    "        start_token: token used to build a [batch_size, 1] context when no\n",
    "            context is given.\n",
    "        batch_size: only used together with start_token.\n",
    "        context: token tensor to condition on; mutually exclusive with start_token.\n",
    "        temperature: Python number or tensor; when > 0, noise scaled by it is\n",
    "            added to the logits before filtering.\n",
    "        top_k: Python number or tensor; used when top_p is not > 0.\n",
    "        top_p: Python number or tensor; when > 0, nucleus filtering is used\n",
    "            instead of top-k filtering.\n",
    "\n",
    "    Returns:\n",
    "        The context with the sampled tokens concatenated along axis 1.\n",
    "    \"\"\"\n",
    "    if start_token is None:\n",
    "        assert (\n",
    "            context is not None\n",
    "        ), 'Specify exactly one of start_token and context!'\n",
    "    else:\n",
    "        assert (\n",
    "            context is None\n",
    "        ), 'Specify exactly one of start_token and context!'\n",
    "        context = tf.fill([batch_size, 1], start_token)\n",
    "\n",
    "    def step(hparams, tokens, past=None):\n",
    "        # One forward pass; `past` is the state returned by earlier calls.\n",
    "        lm_output = gpt2_model.model(\n",
    "            hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE\n",
    "        )\n",
    "\n",
    "        logits = lm_output['logits'][:, :, : hparams.n_vocab]\n",
    "        presents = lm_output['present']\n",
    "        presents.set_shape(\n",
    "            gpt2_model.past_shape(hparams=hparams, batch_size=None)\n",
    "        )\n",
    "        return {'logits': logits, 'presents': presents}\n",
    "\n",
    "    with tf.name_scope('sample_sequence'):\n",
    "        lens = tf.constant(0, dtype=tf.int32)\n",
    "        # Run the model over everything but the last context token; that last\n",
    "        # token is the first `prev` handed to the loop.\n",
    "        context_output = step(hparams, context[:, :-1])\n",
    "\n",
    "        def apply_temp(logits_BxN, temperature):\n",
    "            # Adds -log(-log(u)) * temperature, with u drawn from tf.random_uniform.\n",
    "            logits_shape = tf.shape(logits_BxN)\n",
    "            uniform_noise_BxN = tf.random_uniform(logits_shape)\n",
    "            logits_BxN += -tf.log(-tf.log(uniform_noise_BxN)) * temperature\n",
    "            return logits_BxN\n",
    "\n",
    "        def body(past, prev, output, lens):\n",
    "            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)\n",
    "            logits = next_outputs['logits'][:, -1, :]\n",
    "            # tf.cond does not accept a Python bool predicate, which is what\n",
    "            # `temperature > 0` / `top_p > 0.0` evaluate to for plain Python\n",
    "            # numbers; tf.greater yields a tensor for numbers and tensors alike.\n",
    "            logits = tf.cond(\n",
    "                tf.greater(temperature, 0),\n",
    "                lambda: apply_temp(logits, temperature),\n",
    "                lambda: logits,\n",
    "            )\n",
    "            logits = tf.cond(\n",
    "                tf.greater(top_p, 0.0),\n",
    "                lambda: top_p_logits(logits, p=top_p),\n",
    "                lambda: top_k_logits(logits, k=top_k),\n",
    "            )\n",
    "            samples = tf.random.categorical(\n",
    "                logits, num_samples=1, dtype=tf.int32\n",
    "            )\n",
    "            return [\n",
    "                tf.concat([past, next_outputs['presents']], axis=-2),\n",
    "                tf.squeeze(samples, axis=[1]),\n",
    "                tf.concat([output, samples], axis=1),\n",
    "                lens + 1\n",
    "            ]\n",
    "\n",
    "        def cond(past, prev, output, lens):\n",
    "            return tf.less(lens, length)\n",
    "\n",
    "        # Loop state: cached model state, last token, all tokens so far, step counter.\n",
    "        _, _, tokens, _ = tf.while_loop(\n",
    "            cond=cond,\n",
    "            body=body,\n",
    "            loop_vars=[context_output['presents'], context[:, -1], context, lens],\n",
    "            shape_invariants=[\n",
    "                tf.TensorShape(\n",
    "                    gpt2_model.past_shape(\n",
    "                        hparams=hparams, batch_size=None\n",
    "                    )\n",
    "                ),\n",
    "                tf.TensorShape([None]),\n",
    "                tf.TensorShape([None, None]),\n",
    "                lens.get_shape(),\n",
    "            ],\n",
    "            back_prop=False,\n",
    "        )\n",
    "\n",
    "        return tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d1191da",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    \"\"\"Sampling graph with feedable controls.\n",
    "\n",
    "    Creates the placeholders named 'X', 'temp', 'top_k', 'top_p', 'maxlen' and\n",
    "    'n_samples', passes them to sample_sequence and exposes the sampled\n",
    "    tokens through an identity op named 'output'.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self, hparams, encoder, **kwargs\n",
    "    ):\n",
    "        # Extra keyword arguments are accepted but not used.\n",
    "        self._encoder = encoder\n",
    "\n",
    "        # A single prompt row of variable length.\n",
    "        self._X = tf.placeholder(dtype=tf.int32, shape=[1, None], name='X')\n",
    "        # Sampling controls, fed at run time.\n",
    "        self._temperature = tf.placeholder(dtype=tf.float32, shape=None, name='temp')\n",
    "        self._top_k = tf.placeholder(dtype=tf.int32, shape=None, name='top_k')\n",
    "        self._top_p = tf.placeholder(dtype=tf.float32, shape=None, name='top_p')\n",
    "        self._maxlen = tf.placeholder(dtype=tf.int32, shape=None, name='maxlen')\n",
    "        self._n_samples = tf.placeholder(dtype=tf.int32, shape=None, name='n_samples')\n",
    "\n",
    "        # Repeat the prompt along the first axis, once per requested sample.\n",
    "        tiled_prompt = tf.tile(self._X, multiples=[self._n_samples, 1])\n",
    "        sampled_tokens = sample_sequence(\n",
    "            hparams=hparams,\n",
    "            length=self._maxlen,\n",
    "            context=tiled_prompt,\n",
    "            batch_size=self._n_samples,\n",
    "            temperature=self._temperature,\n",
    "            top_k=self._top_k,\n",
    "            top_p=self._top_p,\n",
    "        )\n",
    "        self._model = sampled_tokens\n",
    "        self.output = tf.identity(sampled_tokens, name='output')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c4a7b0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the placeholders and the sampling graph.\n",
    "model = Model(hparams, enc_malay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd5bfbec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open a session and run the initializer for all global variables.\n",
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f751b579",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore every global variable of the graph from the checkpoint.\n",
    "var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)\n",
    "saver = tf.train.Saver(var_list = var_list)\n",
    "saver.restore(sess, 'gs://mesolitica-tpu-general/gpt2-117m/model.ckpt-435300')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2442dcff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode a sample prompt; the cell displays how many tokens it produced.\n",
    "string = 'mahathir dan najib razak sangat sayangkan anwar ibrahim'\n",
    "encoded = enc_malay.encode(string)\n",
    "len(encoded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "542f0385",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw 10 continuations of 20 tokens each. top_p > 0 selects nucleus filtering\n",
    "# (top_k is then unused) and temperature 0.0 skips the noise step.\n",
    "o = sess.run(\n",
    "    model._model,\n",
    "    feed_dict={\n",
    "        model._X: [encoded],\n",
    "        model._temperature: 0.0,\n",
    "        model._top_k: 0,\n",
    "        model._top_p: 0.7,\n",
    "        model._maxlen: 20,\n",
    "        model._n_samples: 10,\n",
    "    },\n",
    ")\n",
    "o.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00eec77b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode each sampled row back to text.\n",
    "for i in range(o.shape[0]):\n",
    "    print(i, enc_malay.decode(o[i]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ec9b89d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the session's variables to a local checkpoint, which freeze_graph reads later.\n",
    "saver = tf.train.Saver()\n",
    "saver.save(sess, 'gpt2-117m/model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71b58d61",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Comma-separated names of the nodes to keep when freezing: variable and gather ops\n",
    "# plus any node whose name contains one of the placeholder/output tags, excluding\n",
    "# names that contain one of the unwanted tags.\n",
    "# Note: the 'X' tag matches every node name containing an upper-case X.\n",
    "strings = ','.join(\n",
    "    node.name\n",
    "    for node in tf.get_default_graph().as_graph_def().node\n",
    "    if (\n",
    "        'Variable' in node.op\n",
    "        or 'gather' in node.op.lower()\n",
    "        or any(\n",
    "            tag in node.name\n",
    "            for tag in ('X', 'temp', 'top_', 'maxlen', 'n_samples', 'output')\n",
    "        )\n",
    "    )\n",
    "    and not any(\n",
    "        tag in node.name\n",
    "        for tag in ('adam', 'global_step', 'Assign', 'ReadVariableOp', 'Gather')\n",
    "    )\n",
    ")\n",
    "strings.split(',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e458886c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def freeze_graph(model_dir, output_node_names):\n",
    "    \"\"\"Freeze the checkpoint recorded in model_dir into a single GraphDef file.\n",
    "\n",
    "    Args:\n",
    "        model_dir: directory holding the checkpoint state and its .meta file.\n",
    "        output_node_names: comma-separated names of the nodes to keep.\n",
    "\n",
    "    Writes frozen_model.pb into the checkpoint's directory and prints the\n",
    "    number of nodes in the frozen graph.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if model_dir does not exist or holds no checkpoint.\n",
    "    \"\"\"\n",
    "    if not tf.gfile.Exists(model_dir):\n",
    "        raise AssertionError(\n",
    "            \"Export directory doesn't exists. Please specify an export \"\n",
    "            'directory: %s' % model_dir\n",
    "        )\n",
    "\n",
    "    checkpoint = tf.train.get_checkpoint_state(model_dir)\n",
    "    # get_checkpoint_state returns None when no checkpoint state is found; fail\n",
    "    # with a clear message instead of an AttributeError on the attribute access.\n",
    "    if checkpoint is None or not checkpoint.model_checkpoint_path:\n",
    "        raise AssertionError(\n",
    "            'No checkpoint found in directory: %s' % model_dir\n",
    "        )\n",
    "    input_checkpoint = checkpoint.model_checkpoint_path\n",
    "\n",
    "    # Directory part of the checkpoint path; the frozen graph is written there.\n",
    "    absolute_model_dir = '/'.join(input_checkpoint.split('/')[:-1])\n",
    "    output_graph = absolute_model_dir + '/frozen_model.pb'\n",
    "\n",
    "    # Entering the session makes its fresh graph the default graph, so the meta\n",
    "    # graph is imported into it and not into the notebook's working graph.\n",
    "    with tf.Session(graph=tf.Graph()) as sess:\n",
    "        saver = tf.train.import_meta_graph(\n",
    "            input_checkpoint + '.meta', clear_devices=True\n",
    "        )\n",
    "        saver.restore(sess, input_checkpoint)\n",
    "        output_graph_def = tf.graph_util.convert_variables_to_constants(\n",
    "            sess,\n",
    "            tf.get_default_graph().as_graph_def(),\n",
    "            output_node_names.split(','),\n",
    "        )\n",
    "        with tf.gfile.GFile(output_graph, 'wb') as f:\n",
    "            f.write(output_graph_def.SerializeToString())\n",
    "        print('%d ops in the final graph.' % len(output_graph_def.node))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41f38c30",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Freeze the checkpoint in 'gpt2-117m', keeping the nodes listed in `strings`.\n",
    "freeze_graph('gpt2-117m', strings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffeac81e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_graph(frozen_graph_filename):\n",
    "    \"\"\"Read a serialized GraphDef from a file and import it into a new tf.Graph.\n",
    "\n",
    "    import_graph_def is called without a name argument; the cells that use the\n",
    "    returned graph address its tensors as 'import/<node>:0'.\n",
    "    \"\"\"\n",
    "    graph_def = tf.GraphDef()\n",
    "    with tf.gfile.GFile(frozen_graph_filename, 'rb') as pb_file:\n",
    "        graph_def.ParseFromString(pb_file.read())\n",
    "\n",
    "    graph = tf.Graph()\n",
    "    with graph.as_default():\n",
    "        tf.import_graph_def(graph_def)\n",
    "    return graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cbe978b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the frozen graph into its own tf.Graph.\n",
    "g = load_graph('gpt2-117m/frozen_model.pb')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0ec52de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Look up the placeholders and the result tensor of the imported graph by name;\n",
    "# the names carry the 'import/' prefix.\n",
    "input_nodes = ['X', 'temp', 'top_k', 'top_p', 'maxlen', 'n_samples']\n",
    "output_nodes = ['output']\n",
    "inputs = {\n",
    "    n: g.get_tensor_by_name(f'import/{n}:0')\n",
    "    for n in input_nodes\n",
    "}\n",
    "outputs = {\n",
    "    n: g.get_tensor_by_name(f'import/{n}:0')\n",
    "    for n in output_nodes\n",
    "}\n",
    "inputs, outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9893e8ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Session bound to the loaded graph `g`.\n",
    "test_sess = tf.Session(graph = g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2ac02e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the frozen graph once: a single sample of 100 tokens. top_p is 0.0, so the\n",
    "# top-k branch (k = 40) is used, and temperature 0.0 skips the noise step.\n",
    "o = test_sess.run(\n",
    "    outputs['output'],\n",
    "    feed_dict={\n",
    "        inputs['X']: [encoded],\n",
    "        inputs['temp']: 0.0,\n",
    "        inputs['top_k']: 40,\n",
    "        inputs['top_p']: 0.0,\n",
    "        inputs['maxlen']: 100,\n",
    "        inputs['n_samples']: 1,\n",
    "    },\n",
    ")\n",
    "o.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d6b4455",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the first (here: only) sampled row back to text.\n",
    "print(enc_malay.decode(o[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7623f5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.tools.graph_transforms import TransformGraph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3075100b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Graph-transform passes handed to TransformGraph; quantize_weights is the pass\n",
    "# that quantizes the stored weights.\n",
    "transforms = ['add_default_attributes',\n",
    "             'remove_nodes(op=Identity, op=CheckNumerics, op=Dropout)',\n",
    "             'fold_batch_norms',\n",
    "             'fold_old_batch_norms',\n",
    "             'quantize_weights(fallback_min=-10, fallback_max=10)',\n",
    "             'strip_unused_nodes',\n",
    "             'sort_by_execution_order']\n",
    "\n",
    "input_nodes = ['X', 'temp', 'top_k', 'top_p', 'maxlen', 'n_samples']\n",
    "output_nodes = ['output']\n",
    "\n",
    "pb = 'gpt2-117m/frozen_model.pb'\n",
    "\n",
    "# tf.gfile.FastGFile is deprecated in favour of tf.gfile.GFile, which the rest of\n",
    "# this notebook already uses.\n",
    "input_graph_def = tf.GraphDef()\n",
    "with tf.gfile.GFile(pb, 'rb') as f:\n",
    "    input_graph_def.ParseFromString(f.read())\n",
    "\n",
    "transformed_graph_def = TransformGraph(\n",
    "    input_graph_def, input_nodes, output_nodes, transforms\n",
    ")\n",
    "\n",
    "# The transformed graph is written next to the original with a '.quantized' suffix.\n",
    "with tf.gfile.GFile(f'{pb}.quantized', 'wb') as f:\n",
    "    f.write(transformed_graph_def.SerializeToString())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d91b8bbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the quantized graph; this rebinds `g`.\n",
    "g = load_graph('gpt2-117m/frozen_model.pb.quantized')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f0b7feb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Look up the placeholders and the result tensor of the graph currently bound to `g`;\n",
    "# the names carry the 'import/' prefix.\n",
    "input_nodes = ['X', 'temp', 'top_k', 'top_p', 'maxlen', 'n_samples']\n",
    "output_nodes = ['output']\n",
    "inputs = {\n",
    "    n: g.get_tensor_by_name(f'import/{n}:0')\n",
    "    for n in input_nodes\n",
    "}\n",
    "outputs = {\n",
    "    n: g.get_tensor_by_name(f'import/{n}:0')\n",
    "    for n in output_nodes\n",
    "}\n",
    "inputs, outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea326f8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Session bound to the graph currently held in `g`.\n",
    "test_sess = tf.Session(graph = g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "638c68cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# o = test_sess.run(outputs['output'], feed_dict = {inputs['X']: [encoded],\n",
    "#                                   inputs['temp']: 0.0,\n",
    "#                                   inputs['top_k']: 40,\n",
    "#                                   inputs['top_p']: 0.0,\n",
    "#                                   inputs['maxlen']: 100,\n",
    "#                                   inputs['n_samples']: 1})\n",
    "# o.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37047931",
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(enc_malay.decode(o[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6949d9d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from b2sdk.v1 import B2Api, InMemoryAccountInfo\n",
    "\n",
    "# The credentials were never defined anywhere in this notebook, so a fresh kernel\n",
    "# failed with NameError here. Read them from the environment so the notebook runs\n",
    "# top to bottom without secrets being written into it.\n",
    "application_key_id = os.environ['B2_APPLICATION_KEY_ID']\n",
    "application_key = os.environ['B2_APPLICATION_KEY']\n",
    "\n",
    "info = InMemoryAccountInfo()\n",
    "b2_api = B2Api(info)\n",
    "b2_api.authorize_account(\"production\", application_key_id, application_key)\n",
    "# Metadata attached to every upload below.\n",
    "file_info = {'how': 'good-file'}\n",
    "b2_bucket = b2_api.get_bucket_by_name('malaya-model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "038a1eb2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FileVersionInfo('4_zcde33cc461767caf742c0b11_f201775d542477cf3_d20210923_m090906_c000_v0001400_t0050', 'gpt2/117M/model.pb', 498708685, 'application/octet-stream', 'none', {'how': 'good-file'}, 1632388146000, <EncryptionSetting(EncryptionMode.NONE, None, None)>, <LegalHold.UNSET: None>, FileRetentionSetting(None, None), 1632388146000, None, None, None, 'upload', <b2sdk.v1.api.B2Api object at 0x7f26144e2dd8>)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Upload the frozen graph to the bucket as 'gpt2/117M/model.pb'.\n",
    "file = 'gpt2-117m/frozen_model.pb'\n",
    "outPutname = 'gpt2/117M/model.pb'\n",
    "b2_bucket.upload_local_file(\n",
    "    local_file=file,\n",
    "    file_name=outPutname,\n",
    "    file_infos=file_info,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e9a702b8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FileVersionInfo('4_zcde33cc461767caf742c0b11_f202a1847f9337d3b_d20210923_m090925_c000_v0001079_t0000', 'gpt2/117M-quantized/model.pb', 125564697, 'application/octet-stream', 'none', {'how': 'good-file'}, 1632388165000, <EncryptionSetting(EncryptionMode.NONE, None, None)>, <LegalHold.UNSET: None>, FileRetentionSetting(None, None), 1632388165000, None, None, None, 'upload', <b2sdk.v1.api.B2Api object at 0x7f26144e2dd8>)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Upload the quantized graph to the bucket as 'gpt2/117M-quantized/model.pb'.\n",
    "file = 'gpt2-117m/frozen_model.pb.quantized'\n",
    "outPutname = 'gpt2/117M-quantized/model.pb'\n",
    "b2_bucket.upload_local_file(\n",
    "    local_file=file,\n",
    "    file_name=outPutname,\n",
    "    file_infos=file_info,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52090f9b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
